module sched

import time
import rand

import vcp.rtcom
import vcp
import vcp.mlog
import vcp.vmm
import vcp.iohook
import vcp.iopoller
import vcp.chan1

import vcp.futex
import vcp.coro

#include <sys/mman.h>
#include "@VROOT/schedtls.h"

// vnil: untyped null pointer used for "no object" throughout this module.
const vnil = voidptr(0)
// Size constants, in bytes.
const kilobyte = 1024
// NOTE(review): name looks like a typo for "megabyte"; kept as-is because other code may reference it.
const millbyte = 1024*1024
const gigabyte = 1024 * 1024 * 1024

// pre_gc_init runs before the GC is initialized (per the original "gc not
// inited" note). It registers this scheduler's callbacks with rtcom:
// incoro/getcoro/onyield/onyield_multi on the yield side and onresume on the
// resume side. It then hands rtcom's yielder on to iohook.
pub fn pre_gc_init() {
    C.printf("sched pregc init\n")
    mut y := rtcom.Yielder{}
    y.yield = onyield
    y.yield_multi = onyield_multi
    y.incoro = incoro
    y.getcoro = getcoro

    mut r := rtcom.Resumer{}
    r.resume_one = onresume

    rtcom.pre_gc_init(&y, &r, voidptr(0))
    iohook.pre_main_init(rtcom.yielder(), voidptr(0))
}

// pre_main_init runs once the GC is initialized, before main. It connects the
// I/O poller (resume side) and the channel module (yield and resume sides) to
// the yielder/resumer that pre_gc_init registered with rtcom.
pub fn pre_main_init() {
    println("sched premain init")
    mut ylder := rtcom.yielder()
    mut rsmer := rtcom.resumer()
    iopoller.pre_main_init(rsmer)
    chan1.pre_main_init(ylder, rsmer, voidptr(0))
}
// post_main_deinit is the shutdown hook; it currently does nothing.
pub fn post_main_deinit() {
    // TODO: should shut down the running coroutines and the machine (procer) threads
}

// init is the module initializer. It starts the I/O poller, then creates the
// global scheduler object and spawns its machine threads.
// schedobj is published before init_machines so the spawned threads can see it.
fn init() {
    iopoller.start()

    obj := &Schedule{}
    schedobj = obj
    obj.init_machines()
}

////// callbacks registered with the underlying modules (rtcom / iohook)
// incoro reports whether the calling OS thread is a scheduler machine thread,
// i.e. whether its thread-local machine id (set by machine_set) is >= 0.
// Returns 1 if so, otherwise 0.
// NOTE(review): relies on co_sched_mcid's default (declared in schedtls.h, not
// visible here) being negative for threads that never called machine_set.
//
// This is called from low-level hooks, so it must stay allocation-free:
// logging via mlog.info allocates memory and deadlocks libgc here.
fn incoro() int {
    return int(getmcid() >= 0)
}
// getcoro returns the Fiber bound to the calling thread by fiber_set
// (nil after coruner_proc clears it with fiber_set(-1, vnil)).
fn getcoro() voidptr {
    cur := voidptr(C.co_sched_grobj)
    return cur
}

// __thread local
// machine_set records which Machine owns the calling OS thread, in the
// thread-local variables from schedtls.h. Always returns 0.
fn machine_set(mcid int, mcobj voidptr) int {
    C.co_sched_mcobj = mcobj
    C.co_sched_mcid = mcid
    return 0
}
// fiber_set records the fiber now running on the calling thread (its id and
// object pointer) in thread-local storage. coruner_proc passes (-1, nil) once
// the fiber has switched back. Always returns 0.
fn fiber_set(grid int, grobj voidptr) int {
    C.co_sched_grobj = grobj
    C.co_sched_grid = grid
    return 0
}
// getmcid returns the calling thread's machine id, as stored by machine_set.
fn getmcid() int {
    return C.co_sched_mcid
}
// machine_get copies the calling thread's machine id and Machine pointer
// (set by machine_set) into the two out-parameters. Always returns 0.
fn machine_get(mcid &int, mcobj &voidptr) int {
    unsafe {
        *(&int(mcid)) = C.co_sched_mcid
        *(&voidptr(mcobj)) = voidptr(C.co_sched_mcobj)
    }
    return 0
}
// fiber_get copies the calling thread's current fiber id and Fiber pointer
// (set by fiber_set) into the two out-parameters. Always returns 0.
fn fiber_get(mcid &int, mcobj &voidptr) int {
    unsafe {
        *(&int(mcid)) = C.co_sched_grid
        *(&voidptr(mcobj)) = voidptr(C.co_sched_grobj)
    }
    return 0
}

// onyield is the rtcom single-wait yield callback.
// For sleep and fd-based yield types it registers the current fiber with the
// I/O poller (iopoller.yieldfd). Channel yields register nothing here. Any
// other type panics. It then suspends the fiber via swapback.
// fdns: forwarded unchanged to iopoller.yieldfd; presumably an fd, or a
// duration for the sleep types (TODO confirm against iohook).
// Always returns 0 (after the fiber has been resumed).
fn onyield(fdns i64, ytype int) int {
    grid := 0
    grobjx := vnil
    fiber_get(&grid, &grobjx)
    grobj := &Fiber(grobjx)

    if ytype == iohook.YIELD_TYPE_SLEEP ||
        ytype == iohook.YIELD_TYPE_USLEEP ||
        ytype == iohook.YIELD_TYPE_CONNECT ||
        ytype == iohook.YIELD_TYPE_RECV ||
        ytype == iohook.YIELD_TYPE_READ {
        iopoller.yieldfd(fdns, ytype, &iopoller.FiberMin(grobjx))
    } else if ytype == iohook.YIELD_TYPE_CHAN_RECV ||
        ytype == iohook.YIELD_TYPE_CHAN_SEND {
        // nothing to register with the poller for channel waits
    } else {
        panic("unimpl $fdns $ytype")
    }
    grobj.swapback(ytype)
    return 0
}
// onyield_multi is the rtcom multi-fd yield callback.
// It registers the current fiber with the I/O poller once per fd
// (fds[i] waiting for ytypes[i], i < cnt), then suspends it via swapback.
// Supported ytypes: YIELD_TYPE_UUPOLL (used by curl) and
// YIELD_TYPE_RECVMSG_TIMEOUT (used by x11); anything else panics.
// Always returns 0 (after the fiber has been resumed).
fn onyield_multi (ytype int, cnt int, fds &i64, ytypes &int) int {
    grid := 0
    grobjx := vnil
    fiber_get(&grid, &grobjx)
    grobj := &Fiber(grobjx)

    if ytype == iohook.YIELD_TYPE_UUPOLL || ytype == iohook.YIELD_TYPE_RECVMSG_TIMEOUT {
        for i := 0; i < cnt; i++ {
            iopoller.yieldfd(fds[i], ytypes[i], &iopoller.FiberMin(grobjx))
        }
    } else {
        panic("unimpl ytype $ytype cnt $cnt")
    }

    grobj.swapback(ytype)
    return 0
}

// onresume is the rtcom resume callback for a suspended fiber.
// It checks that grid and mcid match the fiber and its owning machine. It then
// marks the fiber coresumed and wakes that machine, so the machine's popnext
// will pick the fiber up again. If the fiber has not yet reached swapback, the
// coresumed state makes swapback's CAS fail, and the fiber is simply re-run.
fn onresume(grx voidptr, ytype int, grid int, mcid int) {
    mut grobj := &Fiber(grx)
    assert grobj.grid == grid
    mcobj := grobj.mcobj
    assert mcobj.mcid == mcid

    grobj.set_state(coresumed)
    // NOTE(review): plain, non-atomic increment; coruner_proc resets this field
    // on the machine thread.
    grobj.wakecnt ++
    mcobj.wake(ytype)
}

///////////

// Schedule is the process-wide scheduler object (the `schedobj` global).
struct Schedule {
    mut:
    mainmc &Machine = vnil // control machine: post3 queues new fibers here and coctrl_proc distributes them
    osths map[int]&Machine // worker machines keyed by mcid, each running coruner_proc on its own thread
    mcidno int // NOTE(review): not referenced in this file
    gridno int = 100 // next fiber id handed out by nextgrid(); ids start at 100
    amu &futex.Mutex = futex.newMutex() // guards gridno in nextgrid()
}

// const schedobj = &Schedule{}
// Module-wide mutable state. schedobj is created in init().
__global (
    schedobj = &Schedule(0)
    ylder2 = &rtcom.Yielder(0) // NOTE(review): not referenced in this file
    // Neither of these two seems to cooperate well with the GC
    useshrstk = bool(false) // has some problem, disable now
    usemmapstk = bool(false) // NOTE(review): not referenced in this file
)

// Machine is one scheduler OS thread: either a worker running coruner_proc or
// the control machine running coctrl_proc (see init_machines).
struct Machine {
    mut:
    mcid int // machine id: 1..3 for workers, -2 for the control machine
    futo &futex.Futex = futex.newFutex() // the machine thread parks on this; wake() signals it
    wakety WakeType // reason passed to the most recent wake()
    parkty WakeType // NOTE(review): not referenced in this file
    mainco voidptr // coro context of the machine's own loop; fibers transfer back to it
    wakecnt int // wake() count (workers reset it before parking); a notification may be missed

    taskq []&Fiber // fibers assigned here whose coroutine context is not created yet; guarded by amu
    amu &futex.Mutex = futex.newMutex() // guards taskq and workq
    workq []&Fiber // fibers with a coroutine context, owned by this machine; guarded by amu
    rungr &Fiber = vnil // fiber currently running on this machine, or nil
    shrstk voidptr // one shrstk per osthread (allocated only when useshrstk)
}

/*
enum Costate {
    coready
    corunning
    coyielded
    coresumed
    codone
}
*/
// Fiber states. These are plain u32 constants rather than the enum above,
// presumably so Fiber.state can be accessed through the C atomic u32 helpers
// (set_state / get_state / set_state_ifeq).
const (
    coready = u32(0) // created and not yet run (zero default); runnable
    corunning = u32(1) // executing on a machine
    coyielded = u32(2) // suspended in swapback, waiting for onresume
    coresumed = u32(3) // woken by onresume; runnable again
    codone = u32(4) // entry function returned; the machine destroys it
)

// Fiber is one user-level coroutine, created by newFiber (via post3).
struct Fiber {
    grid int // unique fiber id from Schedule.nextgrid()
    mut:
    mcid int // id of the machine it was assigned to (set in addnew)

    // channel, sudog struct???
    // NOTE(review): elem..isselect are not used in this file; presumably
    // channel-wait bookkeeping for chan1 (similar to Go's sudog).
    elem voidptr // channel recv/send var addr
    fromgr &Fiber = 0
    channel voidptr // sending channel
    releasetime i64
    param voidptr
    isselect bool

    // state Costate
    state u32 // one of coready/corunning/coyielded/coresumed/codone; use the atomic accessors
    ytype int // yield type recorded by the last successful swapback
    cofn CoFunc // entry point, run by comainfp
    stk voidptr = 0 // bottom (lowest address) of the fiber stack
    stksz int // stack size in bytes
    coctx voidptr // this fiber's coro context
    coctx0 voidptr // context to switch back to: the owning machine's mainco
    ctime time.Time // creation time (destroy logs the fiber's lifetime from it)
    yieldtm time.Time // NOTE(review): not referenced in this file
    mcobj &Machine = vnil // owning machine, set by corofy
    stki vmm.StackInfo // stack bounds handed to vmm.set_stackbottom while the fiber runs
    stksav StackSave // saved stack copy, shared-stack mode only
    wakecnt int // resumes since last scheduled (reset when run, incremented by onresume); original note: "from execution to end"
}

// StackSave holds a heap copy of a fiber's in-use stack for shared-stack mode
// (filled by savestack).
struct StackSave {
    mut:
    stkmem voidptr = vnil // saved bytes (allocated with vmm.mallocuc)
    memsz int = 0 // size recorded at the last buffer (re)allocation, in bytes
}

// CoFunc is a fiber entry point. When `this` is non-nil it is called
// method-style as fnptr2(this, fnarg); otherwise as fnptr(fnarg).
// post3 stores the same function pointer in both fnptr2 and fnptr.
struct CoFunc {
    mut:
    this voidptr // receiver; nil selects fnptr
    fnptr2 fn(this voidptr, arg voidptr)
    fnptr fn(arg voidptr)
    fnarg voidptr
}
// call invokes the entry point: fnptr(fnarg) when no receiver is set,
// otherwise fnptr2(this, fnarg).
fn (this CoFunc) call() {
    if this.this == vnil {
        this.fnptr(this.fnarg)
        return
    }
    this.fnptr2(this.this, this.fnarg)
}

// newFiber allocates a Fiber that runs ff, with a fresh id and creation time.
// Its stack and coroutine context are set up later (post3 / corofy).
fn newFiber(ff CoFunc) &Fiber {
    return &Fiber{
        grid: schedobj.nextgrid()
        cofn: ff
        ctime: time.now()
    }
}
// swapback suspends the current fiber and switches to the machine context
// (coctx0).
// If the fiber is still corunning, it becomes coyielded and ytype is recorded.
// Otherwise (e.g. onresume already set coresumed before we got here), the state
// is left alone so the machine will run the fiber again right away.
// When shared stacks are enabled, the in-use stack is saved first.
fn (thisp &Fiber) swapback(ytype int) {
    mut this := thisp
    if this.set_state_ifeq(coyielded, corunning) {
        this.ytype = ytype
    }

    if useshrstk { this.savestack() }
    coro.transfer(this.coctx, this.coctx0)
}
// set_state_ifeq atomically sets state to new_state only if it currently equals
// expect_state (compare-and-swap). Returns true when the swap happened.
fn (thisp &Fiber) set_state_ifeq(new_state u32, expect_state u32) bool {
    return C.atomic_compare_exchange_strong_u32(&thisp.state, &expect_state, new_state)
}
// set_state atomically stores a new fiber state (one of the co* constants).
fn (thisp &Fiber) set_state(state u32) {
    C.atomic_store_u32(&thisp.state, state)
}
// get_state atomically loads the fiber state.
fn (thisp &Fiber) get_state() u32 {
    return C.atomic_load_u32(&thisp.state)
}

// setstki stores the stack size and fills stki (base, size, top) from stk and
// stksz. vmm.set_stackbottom later uses stki so the GC scans this stack.
fn (thisp &Fiber) setstki(stksz int) {
    mut f := thisp
    f.stksz = stksz
    base := f.stk
    // No handle needed: the GC treats a null handle as the current thread.
    f.stki.membase = base // the GC only wants the bottom
    f.stki.stksz = stksz
    f.stki.stktop = voidptr(size_t(base) + size_t(stksz))
}
// destroy releases a finished fiber's stack and logs its lifetime.
// With shared stacks the stack belongs to the machine, so only the
// reference is dropped.
fn (thisp &Fiber) destroy() {
    mut f := thisp
    if useshrstk {
        f.stk = vnil
    } else {
        f.stackguard(false)
        vmm.freegc(f.stk)
    }
    lifetime := time.now() - f.ctime
    mlog.info(@FILE, @LINE, "coexit", f.grid, f.mcid, lifetime.str())
}

/*
coroutine stack layout (the stack grows down from top):
top -------------
    |
    |  using
    |
dummy -----------
    |  empty
    |
guard -----------
    |  guard
bottom ----------
*/

// savestack copies this fiber's in-use stack into stksav for shared-stack mode
// (called from swapback when useshrstk is set).
// The in-use region runs from the address of the local `dummy` in this frame
// up to stki.stktop. Aborts the process if that size is negative or larger
// than the stack.
// NOTE(review): stksav.memsz is only updated when the buffer grows, so after a
// shallower save it still holds an older, larger size. Also, coruner_proc
// restores the copy to the *bottom* of the shared stack, although it was taken
// from near the top. Shared stacks are disabled (useshrstk = false); both
// issues must be fixed together before enabling them.
fn (thisp &Fiber) savestack() {
    mut this := thisp
    mut stktop := 0
    mut stksz := 0
    mut usesz := 0

    dummy := byte(0) ///// stack pos
    //stktop := voidptr(size_t(this.stk) + size_t(dftstksz)) // this.stki.stktop
    stktop = this.stki.stktop
    stksz = this.stki.stksz
    // bytes between the current frame and the stack top
    usesz = int(size_t(voidptr(stktop)) - size_t(voidptr(&dummy)))
    if false {
    mlog.info(@FILE, @LINE, this.grid, "top", stktop, "dummy", voidptr(&dummy),
              "usesz", usesz, "totsz", stksz)
    }
    if usesz < 0 {
        C.printf("usesz<0 %d\n", usesz)
        vcp.abort()
    }
    if usesz > stksz {
        C.printf("usesz>stksz %d>%d\n", usesz,stksz)
        vcp.abort()
    }
    if this.stksav.memsz < usesz {
        //this.stksav.stkmem = malloc(usesz+1)
        // grow the save buffer (the previous buffer is not freed here)
        this.stksav.stkmem = vmm.mallocuc(usesz+1)
        this.stksav.memsz = usesz
    }
    C.memcpy(this.stksav.stkmem, voidptr(&dummy), usesz)
}
// stackoverflow_check logs "need more stack" when more than 75% of this fiber's
// stack is in use. Usage is measured from stki.stktop down to a local in this
// frame.
// Fibers without stack info (stksz <= 0) are skipped, avoiding a division by zero.
fn (thisp &Fiber) stackoverflow_check() {
    stktop := thisp.stki.stktop
    stksz := thisp.stki.stksz
    if stksz <= 0 {
        return
    }
    dummy := byte(0) // its address marks the current stack position
    usesz := int(size_t(voidptr(stktop)) - size_t(voidptr(&dummy)))
    usept := usesz * 100 / stksz
    if usept > 75 {
        mlog.info(@FILE, @LINE, "need more stack", thisp.grid, usept)
    }
}
// overflow_check runs stackoverflow_check on the fiber bound to the calling
// thread. It is a no-op when no fiber is bound (called outside a coroutine),
// instead of dereferencing nil.
fn overflow_check() {
    grid := 0
    grobjx := vnil
    fiber_get(&grid, &grobjx)
    if grobjx == vnil {
        return
    }
    grobj := &Fiber(grobjx)
    grobj.stackoverflow_check()
}

// C declaration of mprotect(2), used by stackguard.
fn C.mprotect() int
// report when overflow, Cannot access memory at address
// stackguard protects the lowest page of the fiber stack, so an overflow
// faults instead of corrupting adjacent memory.
// NOTE: currently disabled; the leading `if true { return }` makes it a no-op.
// on=true: zero the guard page, make it read-only, then write just above it
// as a self-test. on=false: make it read/write again (destroy calls this
// before freeing the stack).
fn (thisp &Fiber) stackguard(on bool) {
    if true {
        return
    }
    addr := thisp.stk
    guardsz := C.sysconf(C._SC_PAGESIZE) // must round to multiple of pagesize
    if on {
        C.memset(addr, 0, guardsz)
        rv := C.mprotect(addr, guardsz, C.PROT_READ)
        assert rv == 0
        // test 1
        okaddr := voidptr(size_t(addr) + size_t(guardsz))
        C.memcpy(okaddr, &rv, sizeof(rv))
        // below should segfault
        // guardaddr := voidptr(size_t(addr) + size_t(1024*4-1))
        // C.memcpy(guardaddr, &rv, sizeof(rv))
        // C.memcpy(addr, &rv, sizeof(rv))
    }else{
        // rv := C.mprotect(addr, guardsz, C.PROT_READ | C.PROT_WRITE|C.PROT_EXEC)
        rv := C.mprotect(addr, guardsz, C.PROT_READ | C.PROT_WRITE)
        assert rv == 0
    }
}

/////////////////////////////
// init_machines spawns the scheduler's OS threads:
// - three worker machines (mcid 1..3), registered in osths and running coruner_proc;
// - one control machine (mcid -2), stored as mainmc and running coctrl_proc.
fn (thisp &Schedule) init_machines() {
    mut s := thisp
    for id in 1 .. 4 {
        mut worker := &Machine{
            mcid: id
        }
        if useshrstk {
            worker.shrstk = vmm.mallocuc(dftstksz)
        }
        s.osths[id] = worker
        go coruner_proc(worker)
    }
    ctrl := &Machine{
        mcid: -2
    }
    s.mainmc = ctrl
    go coctrl_proc(ctrl)
}

// comainfp is the entry function of every fiber coroutine (installed by corofy).
// It runs the fiber's CoFunc, marks the fiber codone, and switches back to
// the machine context; the machine then destroys the fiber.
fn comainfp(argx voidptr) {
    mut fiber := &Fiber(argx)
    fiber.cofn.call()
    mlog.info(@FILE, @LINE, "cofn done", fiber.grid, fiber.mcid)
    fiber.set_state(codone)
    coro.transfer(fiber.coctx, fiber.coctx0)
}

// taskq management: taskq holds fibers whose coroutine context is not created yet.
// addnew assigns gr to this machine (sets gr.mcid) and appends it to taskq
// under amu.
fn (thisp &Machine) addnew(gr &Fiber) {
    mut m := thisp
    mut fiber := gr
    fiber.mcid = m.mcid
    m.amu.mlock()
    m.taskq << fiber
    m.amu.munlock()
}
// getnew removes and returns the oldest fiber on taskq, or nil when it is empty.
fn (thisp &Machine) getnew() &Fiber {
    mut m := thisp
    m.amu.mlock()
    if m.taskq.len == 0 {
        m.amu.munlock()
        return &Fiber(vnil)
    }
    first := m.taskq[0]
    m.taskq.delete(0)
    m.amu.munlock()
    return first
}
// cntnew returns how many fibers are waiting on taskq (read under amu).
fn (thisp &Machine) cntnew() int {
    mut m := thisp
    m.amu.mlock()
    n := m.taskq.len
    m.amu.munlock()
    return n
}

// workq management: workq holds fibers with a coroutine context on this machine.
// append pushes gr onto the back of workq under amu.
fn (thisp &Machine) append(gr &Fiber) {
    mut m := thisp
    m.amu.mlock()
    m.workq << gr
    m.amu.munlock()
}
// popnext removes and returns the oldest runnable fiber on workq (state coready
// or coresumed), or nil if none is runnable. Runs under amu.
// Scanning from the front gives FIFO order. The previous back-to-front scan
// picked the newest runnable fiber, so a fiber that was re-resumed right after
// being appended could be re-run again and again while older runnable fibers
// waited.
fn (thisp &Machine) popnext() &Fiber {
    mut m := thisp
    mut gr := &Fiber(vnil)
    mut idx := -1
    m.amu.mlock()
    for i := 0; i < m.workq.len; i++ {
        tr := m.workq[i]
        if tr == vnil { C.abort() }
        state := tr.get_state()
        if state == coready || state == coresumed {
            idx = i
            break
        }
    }
    if idx >= 0 {
        gr = m.workq[idx]
        m.workq.delete(idx)
    }
    m.amu.munlock()
    return gr
}


// corofy binds grp to this machine and creates its coroutine context. The
// coroutine starts in comainfp(grp) on grp's own stack and returns to the
// machine's mainco. Returns the new context (also stored in grp.coctx).
fn (thisp &Machine) corofy(grp &Fiber) voidptr {
    ctx := coro.newctx()
    mut fiber := grp
    fiber.mcobj = thisp
    fiber.coctx0 = thisp.mainco
    fiber.coctx = ctx
    coro.create(ctx, comainfp, fiber, fiber.stk, fiber.stksz)
    return ctx
}
// coruner_proc is the run loop of a worker machine's OS thread (started by
// init_machines). It binds `arg` to the thread with machine_set, then loops
// forever:
//   1. park on the machine futex if workq has no runnable fiber and taskq is empty;
//   2. create coroutine contexts for all new fibers on taskq and move them to workq;
//   3. run one runnable fiber until it yields or finishes, then destroy it
//      (codone) or put it back on workq.
fn coruner_proc(arg &Machine) {
    mut mysi := vmm.get_my_stackbottom()
    machine_set(arg.mcid, arg)
    mymcid := 0
    mymcobjx := vnil
    machine_get(&mymcid, &mymcobjx)
    mlog.info(@FILE, @LINE, mymcid, mymcobjx, voidptr(&mymcid))
    mut mymcobj := &Machine(mymcobjx)
    mut myfuto := mymcobj.futo
    // this thread's scheduler context; fibers transfer back into it
    mut mainco := coro.newctx()
    mymcobj.mainco = mainco
    coro.create(mainco, vnil, vnil, vnil, 0)

    for {
        mut grobj := mymcobj.popnext()
        needpark := grobj == vnil && mymcobj.cntnew() == 0

        if needpark {
            mymcobj.wakecnt = 0
            myfuto.park()
        }

        // Give every newly assigned fiber a coroutine context.
        for {
            mut gr := mymcobj.getnew()
            if gr == vnil { break }
            if useshrstk {
                gr.stk = mymcobj.shrstk
                gr.setstki(dftstksz)
            }
            mymcobj.corofy(gr)
            mymcobj.append(gr)
        }
        if grobj == vnil {
            grobj = mymcobj.popnext()
        }
        if grobj != vnil {
            mymcobj.rungr = grobj
            grobj.set_state(corunning)
            grobj.wakecnt = 0
            if useshrstk {
                // NOTE(review): savestack copies from near stki.stktop, but the copy is
                // restored here to the bottom of the shared stack. Shared stacks are
                // disabled (useshrstk = false); fix this before enabling them.
                C.memset(mymcobj.shrstk, 0, dftstksz-5000)
                if grobj.stksav.stkmem != vnil {
                    C.memcpy(mymcobj.shrstk, grobj.stksav.stkmem, grobj.stksav.memsz)
                }
            }
            fiber_set(grobj.grid, grobj)
            // Point the GC at the fiber's stack while it runs; without this,
            // memory allocated inside the coroutine is never reclaimed.
            vmm.set_stackbottom(grobj.stki)
            coro.transfer(mainco, grobj.coctx)
            // Back on the machine's own stack: restore its GC stack bottom.
            vmm.set_stackbottom(mysi)
            fiber_set(-1, vnil)
            mymcobj.rungr = vnil
            if grobj.get_state() == codone {
                grobj.destroy()
            }else{
                mymcobj.append(grobj)
            }
        }
    }

    mlog.info(@FILE, @LINE, mymcid, "exit")
}

// random one
// pickmc returns a uniformly random worker machine from osths:
// rand.intn picks a position in the map's iteration order.
fn (thisp &Schedule) pickmc() &Machine {
    mcidx := rand.intn(thisp.osths.len)
    assert mcidx >= 0
    mut cnter := -1
    mut mcobj := &Machine(vnil)
    for _, obj in thisp.osths {
        cnter++
        if cnter == mcidx {
            mcobj = obj
            break
        }
    }
    return mcobj
}

// coctrl_proc is the run loop of the control machine (Schedule.mainmc).
// post3 queues new fibers on this machine and wakes it. Each wake-up drains
// its taskq: every fiber goes to a randomly picked worker (pickmc), which is
// then woken with .newtask.
fn coctrl_proc(arg &Machine) {
    mut mymcobj := &Machine(arg)
    mut myfuto := mymcobj.futo
    mut scheder := schedobj
    mymcid := arg.mcid

    mlog.info(@FILE, @LINE, mymcid, voidptr(arg))
    for {
        myfuto.park()
        mlog.info(@FILE, @LINE, mymcid, "waked", mymcobj.wakety.str())
        if mymcobj.cntnew() == 0 {
            continue
        }
        // Hand out fibers until taskq is empty.
        for {
            curgr := mymcobj.getnew()
            if curgr == vnil {
                break
            }
            target := scheder.pickmc()
            mlog.info(@FILE, @LINE, "gr#${curgr.grid} move to mc#${target.mcid}")
            target.addnew(curgr)
            target.wake(.newtask)
        }
    }

    mlog.info(@FILE, @LINE, mymcid, "exit")
}

// nextgrid returns a unique fiber id. Ids increase by one from Schedule.gridno's
// initial value; access is serialized by amu.
fn (thisp &Schedule) nextgrid() int {
    mut s := thisp
    s.amu.mlock()
    id := s.gridno
    s.gridno++
    s.amu.munlock()
    return id
}

// Default fiber stack size in bytes (128 KiB). Used when post3 gets
// stksz <= 0, and for per-machine shared stacks.
pub const dftstksz = 128*1024

// post runs the plain function f(arg) in a new fiber with the default stack size.
pub fn post(f voidptr, arg voidptr) {
    post3(vnil, f, arg, dftstksz)
}

// post2 is like post, but when `this` is non-nil the fiber calls f(this, arg).
pub fn post2(this voidptr, f voidptr, arg voidptr) {
    post3(this, f, arg, dftstksz)
}

// post3 creates a fiber running f: f(this, arg) if this is non-nil, else f(arg).
// The stack is stksz bytes (dftstksz when stksz <= 0). The fiber is queued on
// the control machine, which assigns it to a worker.
// With shared stacks enabled, no per-fiber stack is allocated and stksz is ignored.
pub fn post3(this voidptr, f voidptr, arg voidptr, stksz int) {
    mut sch := schedobj
    size := if stksz > 0 { stksz } else { dftstksz }
    mut gr := newFiber(CoFunc{this, f, f, arg})
    if !useshrstk {
        gr.stk = vmm.mallocuc(size)
        gr.setstki(size)
        gr.stackguard(true)
    }
    sch.mainmc.addnew(gr)
    sch.mainmc.wake(.newtask)
}

// WakeType is the reason passed to Machine.wake (stored in Machine.wakety).
// In this file only .newtask is named explicitly; onresume forwards the fiber's
// yield type.
enum WakeType {
    unknown
    newtask
    iocanread
    iocanwrite
    ioconnected
    ioclosed
    timerto
}
// wake records why the machine is being woken, increments its wake counter,
// and signals its futex so a parked machine thread resumes.
fn (thisp &Machine) wake(ty WakeType) {
    mut m := thisp
    m.wakecnt++
    m.wakety = ty
    m.futo.wake()
}

// Features
// [x] stack guard (implemented, but currently disabled: stackguard() returns early)
// [ ] shared stack, less memory usage
// [ ] dynamic increase stack size, sigaltstack
// [ ] mmap

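/// How shared-stack coroutines save and restore their stacks: https://blog.csdn.net/liushengxi_root/article/details/85114692
/// Shared-stack pitfall (stack variables of different coroutines end up at the same address): https://masutangu.com/2018/12/10/libco-share-stack/#%E5%85%B1%E4%BA%AB%E6%A0%88%E6%A8%A1%E5%BC%8F%E9%9A%90%E8%97%8F%E7%9A%84%E5%9D%91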

// Similar projects
// https://github.com/idealvin/co
// https://github.com/owt5008137/libcopp
// https://github.com/SasLuca/libco/tree/master/source
